import gym
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy as dcp
from time import sleep
from mpl_toolkits.mplot3d import Axes3D
from sklearn import linear_model as lin
from sklearn.preprocessing import PolynomialFeatures as Poly

'''
usage:
    get policy by sarsa method

arguments:
    env: environment returned by gym.make()
    track_branches: number of sampled tracks (trajectories)
    policy_steps: the maximum number of steps the agent will try
    epsilon: exploration rate of the epsilon-greedy method
    update_step: Temporal Difference (TD) update step width
    discount: gamma, the discount factor applied to the cumulative reward
    policy: initial policy; if it is None, it is set to follow the uniform distribution
    (note: track_branches, policy_steps and policy are not parameters of
    linear_approxi_sarsa as defined below)

returns:
    policy: policy of env based on the sampled tracks
    Q_table: state-action cumulative reward based on the sampled tracks
    track_branches: number of sampled tracks
'''
def linear_approxi_sarsa(env, max_episodes=100, epsilon=0.1, discount=0.9, memory_size=2000,
        minibatch=32, begin_step=1000, update_step=5, update_target_step=200, poly=3):
    '''
    usage:
        Run epsilon-greedy SARSA on env, approximating Q(state, action) with a
        ridge regression over polynomial features of the concatenated
        (state, action) vector. Transitions are stored in a circular replay
        buffer; once learning has started the regression is refit on a random
        minibatch after every environment step, bootstrapping from a
        periodically refreshed copy of the model. After the last episode the
        stored shaped rewards and the fitted Q values are shown in a 3D
        scatter plot.

    arguments:
        env: environment using the classic gym API (reset() returns the
            observation, step() returns a 4-tuple). The reward shaping reads
            observation components 0 and 2 together with the matching entries
            of observation_space.high, as in CartPole.
        max_episodes: number of episodes to run
        epsilon: probability of taking a random action instead of the greedy one
        discount: gamma, the discount factor in the TD target
        memory_size: number of transitions kept in the replay buffer
        minibatch: number of transitions drawn for each refit
        begin_step: global step count from which refitting starts
        update_step: currently unused; kept so existing calls stay valid
        update_target_step: the target model is replaced by a copy of the
            current model whenever the global step is a multiple of this value
        poly: degree of the polynomial feature expansion

    returns:
        None; progress is printed per episode and a plot is shown at the end.
    '''
    step = 0
    state_n = env.observation_space.shape[0]
    action_n = env.action_space.n

    # Replay row layout: [state | action | reward | next_state | next_action].
    sample_width = state_n * 2 + 3
    next_state_begin = state_n + 2
    next_state_end = next_state_begin + state_n
    samples = np.zeros([memory_size, sample_width])

    # None means "no model has been fitted yet".
    Q_xita = None
    Q_xita_target = None

    def Q_approxi(state, action, Qxita):
        # Q estimate for every (state row, action row) pair, always 1-D.
        if Qxita is None:
            return np.zeros(np.size(action))
        sample_x = np.concatenate((state, action), axis=1)
        x = Poly(degree=poly).fit_transform(sample_x)
        return np.ravel(Qxita.predict(x))

    def update_Qxita(sample_batch, Qxita_target):
        # Fit a fresh model on the minibatch against the SARSA TD target.
        states = sample_batch[:, :state_n]
        actions = sample_batch[:, state_n].reshape(-1, 1)
        rewards = sample_batch[:, state_n + 1]
        # The next state sits between the reward and the trailing next action;
        # slicing from the end would drag next_action in as a state feature.
        next_states = sample_batch[:, next_state_begin:next_state_end]
        next_actions = sample_batch[:, -1].reshape(-1, 1)

        x = Poly(degree=poly).fit_transform(np.concatenate((states, actions), axis=1))
        # NOTE(review): the target bootstraps even from terminal transitions
        # because no done flag is stored in the buffer.
        y = rewards + discount * Q_approxi(next_states, next_actions, Qxita_target)

        # Q_table's approximation function is polyfit(x, y)
        return lin.Ridge(alpha=0.5, fit_intercept=True, max_iter=1000).fit(x, y)

    # choose next action by argmax(a)[predict(x, a)], one action per state row
    def predict_action(state, Qxita):
        if Qxita is None:
            return [env.action_space.sample() for _ in range(state.shape[0])]
        candidate_actions = np.arange(action_n).reshape(-1, 1)
        chosen = []
        for row in state:
            # Pair this single state with every available action.
            tiled = np.tile(row, (action_n, 1))
            q_values = Q_approxi(tiled, candidate_actions, Qxita)
            # Break ties between equally good actions at random.
            best = np.flatnonzero(q_values == q_values.max())
            chosen.append(np.random.choice(best))
        return chosen

    pos_threshold = env.observation_space.high[0] / 2
    angle_threshold = env.observation_space.high[2] / 2

    for episode in range(max_episodes):
        # first state
        state = env.reset()
        action = np.random.choice(action_n)
        Reward = 0

        # begin an episode
        while True:
            env.render()
            # take the action chosen by the epsilon-greedy policy
            next_state, reward, done, _ = env.step(action)
            Reward += reward

            # next action generated by the epsilon-greedy policy
            if np.random.random() < epsilon:
                next_action = env.action_space.sample()
            else:
                next_action = predict_action(next_state.reshape(1, -1), Q_xita)[0]

            # Shaped reward stored for learning: larger the closer component 0
            # (position) and component 2 (angle) are to zero.
            r1 = (pos_threshold - abs(next_state[0])) / pos_threshold - 0.8
            r2 = (angle_threshold - abs(next_state[2])) / angle_threshold - 0.5
            reward = r1 + r2

            # save the sample; once memory_size is exceeded, overwrite from index 0
            samples[step % memory_size, :] = np.hstack((state, action, reward, next_state, next_action))

            # begin to learn
            if step >= begin_step:
                # Rows 0..step are valid until the buffer wraps, so the row
                # just written is eligible too (and the range is never empty).
                filled = min(step + 1, memory_size)
                sample_batch = samples[np.random.randint(filled, size=minibatch)]
                Q_xita = update_Qxita(sample_batch, Q_xita_target)

                if step % update_target_step == 0:
                    Q_xita_target = dcp(Q_xita)

            step += 1
            if done:
                break
            # update state and action
            state = next_state
            action = next_action
        print("episode: {:<5}, reward: {:<10}, average reward: {:<20} in {} steps, epsilon is {}".format(
            episode+1, Reward, step / (episode+1), step, epsilon))

    fig = plt.figure(1)
    # add_subplot accepts the projection keyword; gca(projection=...) is not
    # accepted by newer Matplotlib releases.
    ax = fig.add_subplot(111, projection='3d')
    stored_actions = samples[:, state_n]
    ax.scatter(samples[:, 0], stored_actions, samples[:, state_n+1], marker='.')
    ax.scatter(samples[:, 0], stored_actions,
               Q_approxi(samples[:, :state_n], stored_actions.reshape(-1, 1), Q_xita),
               marker='^', color='r')
    ax.set_zlim(-1.5, 1.5)
    ax.set_xlabel("state")
    ax.set_ylabel("action")
    ax.set_zlabel("reward")
    fig.suptitle("avg_reward: {}    step: {}".format(step/max_episodes, step))

    plt.show()

if __name__ == '__main__':
    # Train on CartPole with the default hyper-parameters. The environment
    # is closed even if training raises or is interrupted.
    env = gym.make('CartPole-v1')
    try:
        linear_approxi_sarsa(env)
    finally:
        env.close()
